Feature engineering for electricity load forecasting#

The purpose of this notebook is to demonstrate how to use skrub and polars to perform feature engineering for electricity load forecasting.

We will build a set of features from different sources:

  • Historical weather data for 10 medium to large urban areas in France;

  • Holidays and calendar features for France;

  • Historical electricity load data for the whole of France.

All these data sources cover a time range from March 23, 2021 to May 31, 2025.

Since our maximum forecasting horizon is 24 hours, we consider that the future weather data is known at any chosen prediction time (a stand-in for weather forecasts). Similarly, the holidays and calendar features are known at prediction time for any point in the future.

Therefore, features derived from the weather and calendar data can be used to engineer “future covariates”. Since the load data is our prediction target, we can also use it to engineer “past covariates” such as lagged features and rolling aggregations, as illustrated in the sketch below.
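For illustration, here is a minimal polars sketch (toy numbers, not part of the notebook’s data) contrasting a past covariate, built with a positive shift looking backward, and a future covariate, built with a negative shift looking forward:

import polars as pl

toy = pl.DataFrame({"load_mw": [50.0, 52.0, 51.0, 55.0]})
toy.with_columns(
    # Past covariate: the value observed 1 step before each row.
    pl.col("load_mw").shift(1).alias("load_mw_lag_1h"),
    # Future covariate: a value aligned 1 step ahead of each row; only valid
    # for data actually known in advance (weather forecasts, calendar),
    # never for the target itself.
    pl.col("load_mw").shift(-1).alias("value_future_1h"),
)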

Environment setup#

If needed (when running under JupyterLite), we install some extra dependencies for this notebook. We need the development version of skrub to be able to use skrub expressions.

%pip install -q https://pypi.anaconda.org/ogrisel/simple/polars/1.24.0/polars-1.24.0-cp39-abi3-emscripten_3_1_58_wasm32.whl
%pip install -q altair holidays https://pypi.anaconda.org/ogrisel/simple/skrub/0.6.dev0/skrub-0.6.dev0-py3-none-any.whl
Note: you may need to restart the kernel to use updated packages.
# The following 3 imports are only needed to work around some limitations
# when using polars in a pyodide/jupyterlite notebook.
import tzdata  # noqa: F401
import pandas as pd
from pyarrow.parquet import read_table

import polars as pl
import skrub
from pathlib import Path
import holidays
import warnings

# Ignore warnings from pkg_resources triggered by Python 3.13's multiprocessing.
warnings.filterwarnings("ignore", category=UserWarning, module="pkg_resources")

Time range#

Let’s define an hourly time range from March 23, 2021 to May 31, 2025 that will be used to join the electricity load data and the weather data. The time range is in the UTC timezone to avoid any ambiguity when joining with the weather data, which is also in UTC.

We wrap the polars dataframe in a skrub variable to benefit from the built-in TableReport display in the notebook. Using the skrub expression system will also be useful later.

time_range_start = pl.datetime(2021, 3, 23, hour=0, time_zone="UTC")
time_range_end = pl.datetime(2025, 5, 31, hour=23, time_zone="UTC")
time = skrub.var(
    "time",
    pl.DataFrame().with_columns(
        pl.datetime_range(
            start=time_range_start,
            end=time_range_end,
            time_zone="UTC",
            interval="1h",
        ).alias("time"),
    ),
)
time
<Var 'time'>

To avoid network issues when running this notebook, the necessary data files have already been downloaded and saved in the datasets folder. See the README.md file for instructions to download the data manually if you want to re-run this notebook with more recent data.

data_source_folder = Path("../datasets")
for data_file in sorted(data_source_folder.iterdir()):
    print(data_file)
../datasets/README.md
../datasets/Total Load - Day Ahead _ Actual_202101010000-202201010000.csv
../datasets/Total Load - Day Ahead _ Actual_202201010000-202301010000.csv
../datasets/Total Load - Day Ahead _ Actual_202301010000-202401010000.csv
../datasets/Total Load - Day Ahead _ Actual_202401010000-202501010000.csv
../datasets/Total Load - Day Ahead _ Actual_202501010000-202601010000.csv
../datasets/weather_bayonne.parquet
../datasets/weather_brest.parquet
../datasets/weather_lille.parquet
../datasets/weather_limoges.parquet
../datasets/weather_lyon.parquet
../datasets/weather_marseille.parquet
../datasets/weather_nantes.parquet
../datasets/weather_paris.parquet
../datasets/weather_strasbourg.parquet
../datasets/weather_toulouse.parquet

Here is a list of 10 medium to large urban areas chosen to approximately cover most regions of France, with a slight emphasis on the most populated areas, which are likely to drive electricity demand.

city_names = [
    "paris",
    "lyon",
    "marseille",
    "toulouse",
    "lille",
    "limoges",
    "nantes",
    "strasbourg",
    "brest",
    "bayonne",
]
all_city_weather_raw = {}
for city_name in city_names:
    all_city_weather_raw[city_name] = pl.from_arrow(
        read_table(f"../datasets/weather_{city_name}.parquet")
    ).with_columns(
        # Ensure the time column has the same dtype across all data sources.
        pl.col("time").dt.cast_time_unit("us"),
    )
all_city_weather_raw["brest"]
shape: (38_688, 7)

time                    | temperature_2m | precipitation | wind_speed_10m | cloud_cover | soil_moisture_1_to_3cm | relative_humidity_2m
datetime[μs, UTC]       | f32            | f32           | f32            | f32         | f32                    | f32
2021-01-01 00:00:00 UTC | null           | null          | null           | null        | null                   | null
2021-01-01 01:00:00 UTC | null           | null          | null           | null        | null                   | null
2021-01-01 02:00:00 UTC | null           | null          | null           | null        | null                   | null
2021-01-01 03:00:00 UTC | null           | null          | null           | null        | null                   | null
2021-01-01 04:00:00 UTC | null           | null          | null           | null        | null                   | null
…                       | …              | …             | …              | …           | …                      | …
2025-05-31 19:00:00 UTC | 17.5175        | 0.0           | 12.0694        | 5.0         | 0.168                  | 73.0
2025-05-31 20:00:00 UTC | 16.2675        | 0.0           | 9.114471       | 99.0        | 0.168                  | 77.0
2025-05-31 21:00:00 UTC | 15.5175        | 0.0           | 7.559999       | 93.0        | 0.169                  | 84.0
2025-05-31 22:00:00 UTC | 15.5675        | 0.0           | 9.0            | 100.0       | 0.17                   | 82.0
2025-05-31 23:00:00 UTC | 15.5675        | 0.0           | 5.506941       | 100.0       | 0.171                  | 81.0
all_city_weather_raw["brest"].drop_nulls(subset=["temperature_2m"])
shape: (36_744, 7)

time                    | temperature_2m | precipitation | wind_speed_10m | cloud_cover | soil_moisture_1_to_3cm | relative_humidity_2m
datetime[μs, UTC]       | f32            | f32           | f32            | f32         | f32                    | f32
2021-03-23 00:00:00 UTC | 4.628          | null          | 10.086427      | null        | null                   | 94.0
2021-03-23 01:00:00 UTC | 5.028          | 0.0           | 11.183201      | 6.0         | null                   | 95.0
2021-03-23 02:00:00 UTC | 5.078          | 0.0           | 10.966713      | 6.0         | null                   | 94.0
2021-03-23 03:00:00 UTC | 4.628          | 0.0           | 10.464797      | 5.0         | null                   | 93.0
2021-03-23 04:00:00 UTC | 4.428          | 0.0           | 10.464797      | 5.0         | null                   | 92.0
…                       | …              | …             | …              | …           | …                      | …
2025-05-31 19:00:00 UTC | 17.5175        | 0.0           | 12.0694        | 5.0         | 0.168                  | 73.0
2025-05-31 20:00:00 UTC | 16.2675        | 0.0           | 9.114471       | 99.0        | 0.168                  | 77.0
2025-05-31 21:00:00 UTC | 15.5175        | 0.0           | 7.559999       | 93.0        | 0.169                  | 84.0
2025-05-31 22:00:00 UTC | 15.5675        | 0.0           | 9.0            | 100.0       | 0.17                   | 82.0
2025-05-31 23:00:00 UTC | 15.5675        | 0.0           | 5.506941       | 100.0       | 0.171                  | 81.0
all_city_weather = time.skb.eval()
for city_name, city_weather_raw in all_city_weather_raw.items():
    all_city_weather = all_city_weather.join(
        city_weather_raw.rename(
            lambda x: x if x == "time" else "weather_" + x + "_" + city_name
        ),
        on="time",
        how="inner",
    )

all_city_weather = skrub.var(
    "all_city_weather",
    all_city_weather,
)
all_city_weather
<Var 'all_city_weather'>

Calendar and holidays features#

We leverage the holidays package to enrich the time range with calendar features such as public holidays in France. We also add features that are useful for time series forecasting, such as the day of the week, the day of the year, and the hour of the day.

Note that the holidays package requires us to extract the date in the French timezone.

The same goes for the calendar features: all time features are extracted from the timestamps converted to the French timezone.
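As a quick sanity check of the holidays API (a standalone snippet, separate from the pipeline below), membership tests work directly on date objects:

import datetime

import holidays

holidays_fr_2024 = holidays.country_holidays("FR", years=[2024])
# Bastille Day (July 14) is a French public holiday.
assert datetime.date(2024, 7, 14) in holidays_fr_2024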

holidays_fr = holidays.country_holidays("FR", years=range(2019, 2026))

fr_time = pl.col("time").dt.convert_time_zone("Europe/Paris")
calendar = time.with_columns(
    [
        fr_time.dt.hour().alias("cal_hour_of_day"),
        fr_time.dt.weekday().alias("cal_day_of_week"),
        fr_time.dt.ordinal_day().alias("cal_day_of_year"),
        fr_time.dt.year().alias("cal_year"),
        fr_time.dt.date().is_in(holidays_fr.keys()).alias("cal_is_holiday"),
    ],
)
calendar
<CallMethod 'with_columns'>

Electricity load data#

Finally, we load the electricity load data. It will be used both as the target variable and to craft lagged and window-aggregated features.

load_data_files = [
    data_file
    for data_file in sorted(data_source_folder.iterdir())
    if data_file.name.startswith("Total Load - Day Ahead")
    and data_file.name.endswith(".csv")
]
electricity_raw = skrub.var(
    "electricity_raw",
    pl.concat(
        [
            pl.from_pandas(pd.read_csv(data_file, na_values=["N/A", "-"])).drop(
                ["Day-ahead Total Load Forecast [MW] - BZN|FR"]
            )
            for data_file in load_data_files
        ],
        how="vertical",
    ),
)
electricity_raw
<Var 'electricity_raw'>

electricity = (
    electricity_raw.with_columns(
        [
            pl.col("Time (UTC)")
            .str.split(by=" - ")
            .list.first()
            .str.to_datetime("%d.%m.%Y %H:%M", time_zone="UTC")
            .alias("time"),
        ]
    )
    .drop(["Time (UTC)"])
    .rename({"Actual Total Load [MW] - BZN|FR": "load_mw"})
    .filter(pl.col("time").dt.minute().eq(0))
    .filter(pl.col("time") >= time_range_start)
    .filter(pl.col("time") <= time_range_end)
    .select(["time", "load_mw"])
)
electricity
<CallMethod 'select'>

electricity.filter(pl.col("load_mw").is_null())
<CallMethod 'filter'>

electricity.filter(
    (pl.col("time") > pl.datetime(2021, 10, 30, hour=10, time_zone="UTC"))
    & (pl.col("time") < pl.datetime(2021, 10, 31, hour=10, time_zone="UTC"))
).skb.eval().plot.line(x="time:T", y="load_mw:Q")

The missing values coincide with the switch from daylight saving time to winter time at the end of October 2021. We fill them in by linear interpolation:

electricity = electricity.with_columns([pl.col("load_mw").interpolate()])
electricity.filter(
    (pl.col("time") > pl.datetime(2021, 10, 30, hour=10, time_zone="UTC"))
    & (pl.col("time") < pl.datetime(2021, 10, 31, hour=10, time_zone="UTC"))
).skb.eval().plot.line(x="time:T", y="load_mw:Q")

Check that the number of rows matches our expectations based on the number of hours that separate the first and the last dates. We can do that by joining with the time range dataframe and checking that the number of rows stays the same.

assert (
    time.join(electricity, on="time", how="inner").shape[0] == time.shape[0]
).skb.eval()

Lagged features#

We can now create some lagged features from the electricity load data.

We will create 3 hourly lagged features, 1 daily lagged feature, and 1 weekly lagged feature. We will also create rolling median and rolling inter-quartile range (IQR) features over the last 24 hours and over the last 7 days.

def iqr(col, *, window_size: int):
    """Inter-quartile range (IQR) of a column."""
    return col.rolling_quantile(0.75, window_size=window_size) - col.rolling_quantile(
        0.25, window_size=window_size
    )


electricity_lagged = electricity.with_columns(
    [pl.col("load_mw").shift(i).alias(f"load_mw_lag_{i}h") for i in range(1, 4)]
    + [
        pl.col("load_mw").shift(24).alias("load_mw_lag_1d"),
        pl.col("load_mw").shift(24 * 7).alias("load_mw_lag_1w"),
        pl.col("load_mw")
        .rolling_median(window_size=24)
        .alias("load_mw_rolling_median_24h"),
        pl.col("load_mw")
        .rolling_median(window_size=24 * 7)
        .alias("load_mw_rolling_median_7d"),
        iqr(pl.col("load_mw"), window_size=24).alias("load_mw_iqr_24h"),
        iqr(pl.col("load_mw"), window_size=24 * 7).alias("load_mw_iqr_7d"),
    ],
)
electricity_lagged
<CallMethod 'with_columns'>

import altair


altair.Chart(electricity_lagged.tail(100).skb.eval()).transform_fold(
    [
        "load_mw",
        "load_mw_lag_1h",
        "load_mw_lag_2h",
        "load_mw_lag_3h",
        "load_mw_lag_1d",
        "load_mw_lag_1w",
        "load_mw_rolling_median_24h",
        "load_mw_rolling_median_7d",
        "load_mw_rolling_iqr_24h",
        "load_mw_rolling_iqr_7d",
    ],
    as_=["key", "load_mw"],
).mark_line(tooltip=True).encode(x="time:T", y="load_mw:Q", color="key:N").interactive()

Investigating outliers in the lagged features#

Let’s use the skrub.TableReport tool to look at the plots of the marginal distribution of the lagged features.

from skrub import TableReport

TableReport(electricity_lagged.skb.eval())

Let’s extract the dates where the inter-quartile range of the load is greater than 15,000 MW.

electricity_lagged.filter(pl.col("load_mw_iqr_7d") > 15_000)[
    "time"
].dt.date().unique().sort().to_list().skb.eval()
[datetime.date(2021, 12, 26),
 datetime.date(2021, 12, 27),
 datetime.date(2021, 12, 28),
 datetime.date(2022, 1, 7),
 datetime.date(2022, 1, 8),
 datetime.date(2023, 1, 19),
 datetime.date(2023, 1, 20),
 datetime.date(2023, 1, 21),
 datetime.date(2024, 1, 10),
 datetime.date(2024, 1, 11),
 datetime.date(2024, 1, 12),
 datetime.date(2024, 1, 13)]

We observe four short date ranges with a high inter-quartile range, all in winter. Let’s plot the electricity load and the IQR feature for the first date range, along with the temperature data for the 10 cities.

altair.Chart(
    electricity_lagged.filter(
        (pl.col("time") > pl.datetime(2021, 12, 1, time_zone="UTC"))
        & (pl.col("time") < pl.datetime(2021, 12, 31, time_zone="UTC"))
    ).skb.eval()
).transform_fold(
    [
        "load_mw",
        "load_mw_iqr_7d",
    ],
).mark_line(
    tooltip=True
).encode(
    x="time:T", y="value:Q", color="key:N"
).interactive()
altair.Chart(
    all_city_weather.filter(
        (pl.col("time") > pl.datetime(2021, 12, 1, time_zone="UTC"))
        & (pl.col("time") < pl.datetime(2021, 12, 31, time_zone="UTC"))
    ).skb.eval()
).transform_fold(
    [f"weather_temperature_2m_{city_name}" for city_name in city_names],
).mark_line(
    tooltip=True
).encode(
    x="time:T", y="value:Q", color="key:N"
).interactive()

Based on the plots above, we can see that the electricity load was high just before the Christmas holidays because of low temperatures. The load then dropped suddenly as temperatures rose right at the start of the end-of-year holidays.

So those outliers do not seem to be caused by a data quality issue but rather by a real change in electricity demand. We could conduct a similar analysis for the other date ranges with a high inter-quartile range, but we will skip that for now.

If we had observed significant data quality issues over extended periods of time, they could have been addressed by removing the corresponding rows from the dataset. However, this would make the lagged and windowing feature engineering challenging to implement correctly. A better approach is to keep a contiguous dataset and assign zero weights to the affected rows when fitting or evaluating the trained models, via the sample_weight parameter, as sketched below.
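Here is a minimal sketch of that weighting approach on synthetic data (the is_affected mask and the stand-in X and y are hypothetical, not derived from this notebook):

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))  # stand-in feature matrix
y = rng.normal(size=100)  # stand-in target

# Hypothetical mask flagging rows affected by data quality issues.
is_affected = np.zeros(100, dtype=bool)
is_affected[10:20] = True

# Zero-weight the affected rows instead of dropping them, so that lagged and
# rolling features computed on the contiguous series remain valid.
sample_weight = np.where(is_affected, 0.0, 1.0)
HistGradientBoostingRegressor(random_state=0).fit(X, y, sample_weight=sample_weight)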

Final dataset#

We now assemble the dataset that will be used to train and evaluate the forecasting models via backtesting.

prediction_time = skrub.var(
    "prediction_time",
    pl.DataFrame().with_columns(
        pl.datetime_range(
            start=time_range_start + pl.duration(days=7),
            end=time_range_end - pl.duration(hours=24),
            time_zone="UTC",
            interval="1h",
        ).alias("prediction_time"),
    ),
)
prediction_time
<Var 'prediction_time'>

import polars.selectors as cs


@skrub.deferred
def build_features(
    prediction_time,
    electricity_lagged,
    all_city_weather,
    calendar,
    future_feature_horizons=[1, 24],
):
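    # Negative shifts align values observed h hours after the prediction time
    # with the prediction-time row, producing the "_future_{h}h" covariates.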

    return (
        prediction_time.join(
            electricity_lagged, left_on="prediction_time", right_on="time"
        )
        .join(
            all_city_weather.select(
                [pl.col("time")]
                + [
                    pl.col(c).shift(-h).alias(c + f"_future_{h}h")
                    for c in all_city_weather.columns
                    if c != "time"
                    for h in future_feature_horizons
                ]
            ),
            left_on="prediction_time",
            right_on="time",
        )
        .join(
            calendar.select(
                [pl.col("time")]
                + [
                    pl.col(c).shift(-h).alias(c + f"_future_{h}h")
                    for c in calendar.columns
                    if c != "time"
                    for h in future_feature_horizons
                ]
            ),
            left_on="prediction_time",
            right_on="time",
        )
    ).drop("prediction_time")


features = build_features(
    prediction_time=prediction_time,
    electricity_lagged=electricity_lagged,
    all_city_weather=all_city_weather,
    calendar=calendar,
).skb.mark_as_X()

features
<Call 'build_features'>

horizons = range(1, 25)  # Forecasting horizons from 1 to 24 hours
horizon_of_interest = horizons[-1]  # Focus on the 24-hour horizon

target_column_name_pattern = "load_mw_horizon_{horizon}h"

targets = prediction_time.join(
    electricity.with_columns(
        [
            pl.col("load_mw")
            .shift(-h)
            .alias(target_column_name_pattern.format(horizon=h))
            for h in horizons
        ]
    ),
    left_on="prediction_time",
    right_on="time",
)
target_column_name = target_column_name_pattern.format(horizon=horizon_of_interest)
predicted_target_column_name = "predicted_" + target_column_name
target = targets[target_column_name].skb.mark_as_y()
from sklearn.ensemble import HistGradientBoostingRegressor
import skrub.selectors as s


predictions = features.skb.apply(
    skrub.DropCols(
        cols=skrub.choose_from(
            {
                "none": s.glob(""),  # No column has an empty name.
                "load": s.glob("load_*"),
                "rolling_load": s.glob("load_mw_rolling_*"),
                "weather": s.glob("weather_*"),
                "temperature": s.glob("weather_temperature_*"),
                "moisture": s.glob("weather_moisture_*"),
                "cloud_cover": s.glob("weather_cloud_cover_*"),
                "calendar": s.glob("cal_*"),
                "holiday": s.glob("cal_is_holiday*"),
                "future_1h": s.glob("*_future_1h"),
                "future_24h": s.glob("*_future_24h"),
                "non_paris_weather": s.glob("weather_*") & ~s.glob("weather_*_paris_*"),
            },
            name="dropped_features",
        )
    )
).skb.apply(
    HistGradientBoostingRegressor(
        random_state=0,
        learning_rate=skrub.choose_float(
            0.01, 1, default=0.1, log=True, name="learning_rate"
        ),
        max_leaf_nodes=skrub.choose_int(
            3, 300, default=30, log=True, name="max_leaf_nodes"
        ),
    ),
    y=target,
)
predictions
<Apply HistGradientBoostingRegressor>

altair.Chart(
    pl.concat(
        [
            targets.skb.eval(),
            predictions.rename(
                {target_column_name: predicted_target_column_name}
            ).skb.eval(),
        ],
        how="horizontal",
    ).tail(24 * 7)
).transform_fold(
    [target_column_name, predicted_target_column_name],
).mark_line(
    tooltip=True
).encode(
    x="prediction_time:T", y="value:Q", color="key:N"
).interactive()
from sklearn.model_selection import TimeSeriesSplit


max_train_size = 2 * 52 * 24 * 7  # max ~2 years of training data
test_size = 24 * 7 * 24  # 24 weeks of test data
gap = 7 * 24  # 1 week gap between train and test sets
ts_cv_5 = TimeSeriesSplit(
    n_splits=5, max_train_size=max_train_size, test_size=test_size, gap=gap
)

for cv_idx, (train_idx, test_idx) in enumerate(
    ts_cv_5.split(prediction_time.skb.eval())
):
    print(f"CV iteration #{cv_idx}")
    train_datetimes = prediction_time.skb.eval()[train_idx]
    test_datetimes = prediction_time.skb.eval()[test_idx]
    print(
        f"Train: {train_datetimes.shape[0]} rows, "
        f"Test: {test_datetimes.shape[0]} rows"
    )
    print(f"Train time range: {train_datetimes[0, 0]} to " f"{train_datetimes[-1, 0]} ")
    print(f"Test time range: {test_datetimes[0, 0]} to " f"{test_datetimes[-1, 0]} ")
CV iteration #0
Train: 16224 rows, Test: 4032 rows
Train time range: 2021-03-30 00:00:00+00:00 to 2023-02-03 23:00:00+00:00 
Test time range: 2023-02-11 00:00:00+00:00 to 2023-07-28 23:00:00+00:00 
CV iteration #1
Train: 17472 rows, Test: 4032 rows
Train time range: 2021-07-24 00:00:00+00:00 to 2023-07-21 23:00:00+00:00 
Test time range: 2023-07-29 00:00:00+00:00 to 2024-01-12 23:00:00+00:00 
CV iteration #2
Train: 17472 rows, Test: 4032 rows
Train time range: 2022-01-08 00:00:00+00:00 to 2024-01-05 23:00:00+00:00 
Test time range: 2024-01-13 00:00:00+00:00 to 2024-06-28 23:00:00+00:00 
CV iteration #3
Train: 17472 rows, Test: 4032 rows
Train time range: 2022-06-25 00:00:00+00:00 to 2024-06-21 23:00:00+00:00 
Test time range: 2024-06-29 00:00:00+00:00 to 2024-12-13 23:00:00+00:00 
CV iteration #4
Train: 17472 rows, Test: 4032 rows
Train time range: 2022-12-10 00:00:00+00:00 to 2024-12-06 23:00:00+00:00 
Test time range: 2024-12-14 00:00:00+00:00 to 2025-05-30 23:00:00+00:00 
from sklearn.metrics import make_scorer, mean_absolute_percentage_error, get_scorer


mape_scorer = make_scorer(mean_absolute_percentage_error)

cv_results = predictions.skb.cross_validate(
    cv=ts_cv_5,
    scoring={
        "r2": get_scorer("r2"),
        "mape": mape_scorer,
    },
    return_train_score=True,
    return_pipeline=True,
    verbose=1,
    n_jobs=-1,
)
cv_results.round(3)
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:    8.5s finished
   fit_time  score_time  test_r2  train_r2  test_mape  train_mape  pipeline
0     2.940       0.059    0.963     0.994      0.027       0.012  SkrubPipeline(expr=<Apply HistGradientBoosting...
1     3.261       0.059    0.978     0.994      0.024       0.013  SkrubPipeline(expr=<Apply HistGradientBoosting...
2     3.222       0.054    0.974     0.993      0.023       0.014  SkrubPipeline(expr=<Apply HistGradientBoosting...
3     3.179       0.059    0.980     0.993      0.019       0.014  SkrubPipeline(expr=<Apply HistGradientBoosting...
4     2.171       0.036    0.977     0.993      0.023       0.014  SkrubPipeline(expr=<Apply HistGradientBoosting...
def collect_cv_predictions(pipelines, cv_splitter, predictions, prediction_time):

    index_generator = cv_splitter.split(prediction_time.skb.eval())
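    # Each call to `splitter` (defined below) consumes the next (train, test)
    # pair from `index_generator`, keeping it in sync with the loop over the
    # fitted pipelines.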

    def splitter(X, y, index_generator):
        """Workaround to transform a scikit-learn splitter into a function understood
        by `skrub.train_test_split`."""
        train_idx, test_idx = next(index_generator)
        return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

    results = []

    for (_, test_idx), pipeline in zip(
        cv_splitter.split(prediction_time.skb.eval()), pipelines
    ):
        split = predictions.skb.train_test_split(
            predictions.skb.get_data(),
            splitter=splitter,
            index_generator=index_generator,
        )
        results.append(
            pl.DataFrame(
                {
                    "prediction_time": prediction_time.skb.eval()[test_idx],
                    "load_mw": split["y_test"],
                    "predicted_load_mw": pipeline.predict(split["test"]),
                }
            )
        )
    return results
cv_predictions = collect_cv_predictions(
    cv_results["pipeline"], ts_cv_5, predictions, prediction_time
)
cv_predictions[0]
shape: (4_032, 3)

prediction_time         | load_mw | predicted_load_mw
datetime[μs, UTC]       | f64     | f64
2023-02-11 00:00:00 UTC | 59258.0 | 59855.334418
2023-02-11 01:00:00 UTC | 58654.0 | 59958.654564
2023-02-11 02:00:00 UTC | 56155.0 | 57666.184522
2023-02-11 03:00:00 UTC | 54463.0 | 55832.880673
2023-02-11 04:00:00 UTC | 54698.0 | 57121.984097
…                       | …       | …
2023-07-28 19:00:00 UTC | 38781.0 | 40093.987086
2023-07-28 20:00:00 UTC | 38455.0 | 39343.771368
2023-07-28 21:00:00 UTC | 39972.0 | 40738.151594
2023-07-28 22:00:00 UTC | 39825.0 | 39449.468131
2023-07-28 23:00:00 UTC | 36822.0 | 35828.293662
import numpy as np


def plot_reliability_diagram(cv_predictions, n_bins=10):
    # min and max load over all predictions and observations for any fold:
    all_loads = pl.concat(
        [
            cv_prediction.select(["load_mw", "predicted_load_mw"])
            for cv_prediction in cv_predictions
        ]
    )
    min_load = min(all_loads["load_mw"].min(), all_loads["predicted_load_mw"].min())
    max_load = max(all_loads["load_mw"].max(), all_loads["predicted_load_mw"].max())
    scale = altair.Scale(domain=[min_load, max_load])
    chart = None
    for i, cv_predictions_i in enumerate(cv_predictions):
        mean_per_bins = (
            cv_predictions_i.group_by(
                pl.col("predicted_load_mw").qcut(np.linspace(0, 1, n_bins))
            )
            .agg(
                [
                    pl.col("load_mw").mean().alias("mean_load_mw"),
                    pl.col("predicted_load_mw").mean().alias("mean_predicted_load_mw"),
                ]
            )
            .sort("predicted_load_mw")
        )

        this_chart = (
            altair.Chart(mean_per_bins)
            .mark_line(tooltip=True)
            .encode(
                x=altair.X("mean_predicted_load_mw:Q", scale=scale),
                y=altair.Y("mean_load_mw:Q", scale=scale),
            )
        )
        if chart is None:
            chart = this_chart
        else:
            chart += this_chart
    return chart


plot_reliability_diagram(cv_predictions).interactive()
ts_cv_2 = TimeSeriesSplit(
    n_splits=2, test_size=test_size, max_train_size=max_train_size, gap=24
)
randomized_search = predictions.skb.get_randomized_search(
    cv=ts_cv_2,
    scoring="r2",
    n_iter=100,
    fitted=True,
    verbose=1,
    n_jobs=-1,
)
Fitting 2 folds for each of 100 candidates, totalling 200 fits
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
...
KeyboardInterrupt:
randomized_search.results_.round(3)
randomized_search.plot_results().update_layout(margin=dict(l=150))
# nested_cv_results = skrub.cross_validate(
#     environment=predictions.skb.get_data(),
#     pipeline=randomized_search,
#     cv=ts_cv_5,
#     scoring={
#         "r2": get_scorer("r2"),
#         "mape": mape_scorer,
#     },
#     n_jobs=-1,
#     return_pipeline=True,
# ).round(3)
# nested_cv_results
# for outer_cv_idx in range(len(nested_cv_results)):
#     print(
#         nested_cv_results.loc[outer_cv_idx, "pipeline"]
#         .results_.loc[0]
#         .round(3)
#         .to_dict()
#     )
from sklearn.multioutput import MultiOutputRegressor

model = MultiOutputRegressor(
    estimator=HistGradientBoostingRegressor(random_state=0), n_jobs=-1
)
multioutput_predictions = features.skb.apply(
    model, y=targets.skb.drop(cols=["prediction_time", "load_mw"]).skb.mark_as_y()
).skb.set_name("multioutput_gbdt")
target_column_names = [target_column_name_pattern.format(horizon=h) for h in horizons]
predicted_target_column_names = [
    f"predicted_{target_column_name}" for target_column_name in target_column_names
]
named_predictions = multioutput_predictions.rename(
    dict(zip(target_column_names, predicted_target_column_names))
)
import datetime


def plot_horizon_forecast(
    targets, named_predictions, plot_at_time, historical_timedelta
):
    """Plot the true target and the forecast values."""
    merged_data = targets.skb.select(cols=["prediction_time", "load_mw"]).skb.concat(
        [named_predictions], axis=1
    )
    start_time = plot_at_time - historical_timedelta
    end_time = plot_at_time + datetime.timedelta(
        hours=named_predictions.skb.eval().shape[1]
    )
    true_values_past = merged_data.filter(
        pl.col("prediction_time").is_between(start_time, plot_at_time, closed="both")
    ).rename({"load_mw": "Past true load"})
    true_values_future = merged_data.filter(
        pl.col("prediction_time").is_between(plot_at_time, end_time, closed="both")
    ).rename({"load_mw": "Future true load"})
    predicted_record = (
        merged_data.skb.select(
            cols=skrub.selectors.filter_names(str.startswith, "predict")
        )
        .row(by_predicate=pl.col("prediction_time") == plot_at_time, named=True)
        .skb.eval()
    )
    forecast_values = pl.DataFrame(
        {
            "prediction_time": predicted_record["prediction_time"]
            + datetime.timedelta(hours=horizon),
            "Forecast load": predicted_record[
                "predicted_" + target_column_name_pattern.format(horizon=horizon)
            ],
        }
        for horizon in range(1, len(predicted_record))
    )

    true_values_past_chart = (
        altair.Chart(true_values_past.skb.eval())
        .transform_fold(["Past true load"])
        .mark_line(tooltip=True)
        .encode(x="prediction_time:T", y="Past true load:Q", color="key:N")
    )
    true_values_future_chart = (
        altair.Chart(true_values_future.skb.eval())
        .transform_fold(["Future true load"])
        .mark_line(tooltip=True)
        .encode(x="prediction_time:T", y="Future true load:Q", color="key:N")
    )
    forecast_values_chart = (
        altair.Chart(forecast_values)
        .transform_fold(["Forecast load"])
        .mark_line(tooltip=True)
        .encode(x="prediction_time:T", y="Forecast load:Q", color="key:N")
    )
    return (
        true_values_past_chart + true_values_future_chart + forecast_values_chart
    ).interactive()
plot_at_time = datetime.datetime(2025, 5, 24, 0, 0, tzinfo=datetime.timezone.utc)
historical_timedelta = datetime.timedelta(hours=24 * 5)
plot_horizon_forecast(targets, named_predictions, plot_at_time, historical_timedelta)
plot_at_time = datetime.datetime(2025, 5, 25, 0, 0, tzinfo=datetime.timezone.utc)
plot_horizon_forecast(targets, named_predictions, plot_at_time, historical_timedelta)
from sklearn.metrics import r2_score


def multioutput_scorer(regressor, X, y, score_func, score_name):
    y_pred = regressor.predict(X)
    return {
        f"{score_name}_horizon_{h}h": score
        for h, score in enumerate(
            score_func(y, y_pred, multioutput="raw_values"), start=1
        )
    }


def scoring(regressor, X, y):
    return {
        **multioutput_scorer(regressor, X, y, mean_absolute_percentage_error, "mape"),
        **multioutput_scorer(regressor, X, y, r2_score, "r2"),
    }


multioutput_cv_results = multioutput_predictions.skb.cross_validate(
    cv=ts_cv_5,
    scoring=scoring,
    return_train_score=True,
    verbose=1,
    n_jobs=-1,
).round(3)
multioutput_cv_results
import itertools
from IPython.display import display

for metric_name, dataset_type in itertools.product(["mape", "r2"], ["train", "test"]):
    columns = multioutput_cv_results.columns[
        multioutput_cv_results.columns.str.startswith(f"{dataset_type}_{metric_name}")
    ]
    data_to_plot = multioutput_cv_results[columns]
    data_to_plot.columns = [
        col.replace(f"{dataset_type}_", "")
        .replace(f"{metric_name}_", "")
        .replace("_", " ")
        for col in columns
    ]

    data_long = data_to_plot.melt(var_name="horizon", value_name="score")
    chart = (
        altair.Chart(
            data_long,
            title=f"{dataset_type.title()} {metric_name.upper()} Scores by Horizon",
        )
        .mark_boxplot(extent="min-max")
        .encode(
            x=altair.X(
                "horizon:N",
                title="Horizon",
                sort=altair.Sort(
                    [f"horizon {h}h" for h in range(1, data_to_plot.shape[1])]
                ),
            ),
            y=altair.Y("score:Q", title=f"{metric_name.upper()} Score"),
            color=altair.Color("horizon:N", legend=None),
        )
    )

    display(chart)